{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Datasets Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import asyncio\n",
    "import logging\n",
    "import requests\n",
    "from dataclasses import dataclass, field\n",
    "from pathlib import Path\n",
    "from typing import Iterable, Tuple, Union\n",
    "\n",
    "import aiohttp\n",
    "from tqdm import tqdm\n",
    "\n",
    "logging.basicConfig(level=logging.INFO)\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def extract_file(filepath, directory):\n",
    "    filepath = Path(filepath)\n",
    "    if '.zip' in filepath.suffix:\n",
    "        import zipfile\n",
    "        logger.info('Decompressing zip file...')\n",
    "        with zipfile.ZipFile(filepath, 'r') as zip_ref:\n",
    "            zip_ref.extractall(directory)\n",
    "    else:\n",
    "        from patoolib import extract_archive\n",
    "        extract_archive(filepath, outdir=directory)\n",
    "    logger.info(f'Successfully decompressed {filepath}')\n",
    "\n",
    "def download_file(directory: str, source_url: str, decompress: bool = False) -> None:\n",
    "    \"\"\"Download data from source_ulr inside directory.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    directory: str, Path\n",
    "        Custom directory where data will be downloaded.\n",
    "    source_url: str\n",
    "        URL where data is hosted.\n",
    "    decompress: bool\n",
    "        Wheter decompress downloaded file. Default False.\n",
    "    \"\"\"\n",
    "    if isinstance(directory, str):\n",
    "        directory = Path(directory)\n",
    "    directory.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    filename = Path(source_url.split('/')[-1])\n",
    "\n",
    "    # On windows file must have only zip in suffix\n",
    "    if '.zip' in filename.suffix:\n",
    "        filename = Path(filename).stem + \".zip\"\n",
    "\n",
    "    filepath = Path(f'{directory}/{filename}')\n",
    "\n",
    "    # Streaming, so we can iterate over the response.\n",
    "    headers = {'User-Agent': 'Mozilla/5.0'}\n",
    "    r = requests.get(source_url, stream=True, headers=headers)\n",
    "    # Total size in bytes.\n",
    "    total_size = int(r.headers.get('content-length', 0))\n",
    "    block_size = 1024 #1 Kibibyte\n",
    "\n",
    "    t = tqdm(total=total_size, unit='iB', unit_scale=True)\n",
    "    with open(filepath, 'wb') as f:\n",
    "        for data in r.iter_content(block_size):\n",
    "            t.update(len(data))\n",
    "            f.write(data)\n",
    "            f.flush()\n",
    "    t.close()\n",
    "\n",
    "    if total_size != 0 and t.n != total_size:\n",
    "        logger.error('ERROR, something went wrong downloading data')\n",
    "\n",
    "    size = filepath.stat().st_size\n",
    "    logger.info(f'Successfully downloaded {filename}, {size}, bytes.')\n",
    "\n",
    "    if decompress:\n",
    "        extract_file(filepath, directory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "async def _async_download_file(session: aiohttp.ClientSession, path: Path, source_url: str):\n",
    "    async with session.get(source_url) as response:\n",
    "        content = await response.read()\n",
    "    fname = source_url.split('/')[-1]\n",
    "    (path / fname).write_bytes(content)\n",
    "    return fname"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "async def async_download_files(path: Union[str, Path], urls: Iterable[str]):\n",
    "    path = Path(path)\n",
    "    path.mkdir(exist_ok=True, parents=True)\n",
    "    tasks = []\n",
    "    async with aiohttp.ClientSession() as session:\n",
    "        for url in urls:\n",
    "            tasks.append(_async_download_file(session, path, url))\n",
    "        for task in asyncio.as_completed(tasks):\n",
    "            fname = await task\n",
    "            logger.info(f'Downloaded: {fname}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tempfile\n",
    "\n",
    "import requests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gh_url = 'https://api.github.com/repos/Nixtla/datasetsforecast/git/trees/main'\n",
    "base_url = 'https://raw.githubusercontent.com/Nixtla/datasetsforecast/main'\n",
    "\n",
    "resp = requests.get(gh_url)\n",
    "urls = []\n",
    "for e in resp.json()['tree']:\n",
    "    if e['type'] == 'blob':\n",
    "        urls.append(f'{base_url}/{e[\"path\"]}')\n",
    "with tempfile.TemporaryDirectory() as tmp:\n",
    "    tmp = Path(tmp)\n",
    "    await async_download_files(tmp, urls)\n",
    "    files = list(tmp.iterdir())\n",
    "    assert len(files) == len(urls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export        \n",
    "def download_files(directory: Union[str, Path], urls: Iterable[str]):\n",
    "    loop = asyncio.get_event_loop()\n",
    "    if loop.is_running():\n",
    "        raise Exception(\n",
    "            \"Can't use this function when there's already a running loop. \"\n",
    "            \"Use `await async_download_files(...) instead.`\"\n",
    "        )\n",
    "    asyncio.run(async_download_files(directory, urls))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tempfile.TemporaryDirectory() as tmp:\n",
    "    tmp = Path(tmp)\n",
    "    fname = tmp / 'script.py'\n",
    "    fname.write_text(f\"\"\"\n",
    "from datasetsforecast.utils import download_files\n",
    "    \n",
    "download_files('{tmp.as_posix()}', {urls})\n",
    "    \"\"\")\n",
    "    !python {fname}\n",
    "    fname.unlink()\n",
    "    files = list(tmp.iterdir())\n",
    "    assert len(files) == len(urls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@dataclass\n",
    "class Info:\n",
    "    \"\"\"\n",
    "    Info Dataclass of datasets.\n",
    "    Args:\n",
    "        groups (Tuple): Tuple of str groups\n",
    "        class_groups (Tuple): Tuple of dataclasses.\n",
    "    \"\"\"\n",
    "    class_groups: Tuple[dataclass] \n",
    "    groups: Tuple[str] = field(init=False)\n",
    "    \n",
    "    def __post_init__(self):\n",
    "        self.groups = tuple(cls_.__name__ for cls_ in self.class_groups)\n",
    "\n",
    "    def get_group(self, group: str):\n",
    "        \"\"\"Gets dataclass of group.\"\"\"\n",
    "        if group not in self.groups:\n",
    "            raise Exception(f'Unknown group {group}')\n",
    "        return self.class_groups[self.groups.index(group)]\n",
    "    \n",
    "    def __getitem__(self, group: str):\n",
    "        \"\"\"Gets dataclass of group.\"\"\"\n",
    "        if group not in self.groups:\n",
    "            raise Exception(f'Unknown group {group}')\n",
    "        return self.class_groups[self.groups.index(group)]\n",
    "    \n",
    "    def __iter__(self):\n",
    "        for group in self.groups:\n",
    "            yield group, self.get_group(group)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
